#%%
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from collections import defaultdict
from pathlib import Path
from scipy.stats import norm, t as tdist

from network import *
from utils  import *

#%%
# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
d = 5                  # dimension of the source / target samples
n_list = [200]         # training sample sizes to sweep over
n_reps = 2             # repetitions per sample size (50 for the full study)
device = "cpu"         # torch device the training tensors are moved to
seed0 = 12345          # base seed; repetition r is seeded with seed0 + r
source = "normal"      # source distribution: "normal", "t" or "uniform"
target = "t"           # target distribution: "normal", "t" or "uniform"

# Sieve constants C; run_tnn uses R = max||X|| + C * max||Y|| as the sieve.
Cs = [0, np.inf]

# Output location: next to this script when run as a file, otherwise the
# current working directory (interactive sessions have no __file__).
try:
    script_path = Path(__file__).resolve()
    project_root = script_path.parent
except NameError:
    project_root = Path.cwd()
base_dir = project_root / "output"

fig_dir = os.path.join(base_dir, f"figures-TNN-d{d}")
os.makedirs(fig_dir, exist_ok=True)

#%%
# sieved-TNN estimator, uses pre-sampled X_np, Y_np

def run_tnn(n, C, r, X_np, Y_np):
    # Generate Data
    X = torch.from_numpy(X_np.copy()).float().to(device)
    Y = torch.from_numpy(Y_np.copy()).float().to(device)
    max_normX = np.linalg.norm(X_np, axis=1).max()
    max_normY = np.linalg.norm(Y_np, axis=1).max()
    
    # Train Model
    R = (max_normX + C*max_normY)
    epoch1 = max(int(10000/n),50)

    model = TwoHiddenTanhFNN(d, 10, 1) 
    brenier_model, tr_loss, val_loss = train_brenier(
        x=X, y=Y,
        batch_size=64, model=model, val_ratio=0.1,
        sieved=R,
        learning_rate1=0.005, learning_rate2=0.001,
        num_epochs1=epoch1, num_epochs2=300,
        num_epochs_convex_conjugate=300,
        bar=True, save=False
    )

    # Test Model
    N_test = int(1e6)
    if source == "normal":
        X_test_np = np.random.multivariate_normal(np.zeros(d), np.eye(d), size=N_test)
    if source == "t":  # t
        X_test_np = generate_multivariate_t(df=6, d=d, n_samples=N_test)
    if source == "uniform":
        X_test_np = np.random.uniform(0, 1, size=(N_test, d))

    X_test = torch.from_numpy(X_test_np.copy()).float().to(device)

    Y_hats = []
    for i in range(0, N_test, 2000):
        Xbatch = X_test[i:i+2000]
        Y_batch = compute_gradients(brenier_model, Xbatch)
        Y_hats.append(Y_batch.cpu())
    Y_hat = torch.cat(Y_hats, dim=0)

    if source == "normal":
        F = norm.cdf
    if source == "t":
        F = lambda x: tdist.cdf(x, df=6)
    if source == "uniform":
        F = lambda u:u    # For U[0,1]
    U = F(X_test_np)

    if target == "normal":
        Ginv = norm.ppf
        Y_true_np = Ginv(U)
    if target == "uniform":
        Ginv = lambda u: u  # identity for U[0,1]
        Y_true_np = Ginv(U)
    if target == "t":
        Ginv = lambda u: tdist.ppf(u, df=6)
        Y_true_np = Ginv(U)
        
    cov = np.cov(Y_true_np, rowvar=False)
    var_target = np.trace(cov)

    Y_true = torch.from_numpy(Y_true_np.copy()).float().to(device)
    var_target = torch.as_tensor(var_target)

    # L2‐UVP
    mse = torch.mean(torch.sum((Y_hat - Y_true)**2, dim=1))
    l2_UVP = (mse/var_target).item()

    return n, C, r, l2_UVP, tr_loss, val_loss

#%%
# One (n, C, r, l2_UVP, tr_loss, val_loss) tuple per (sample size, repetition, C).
results = []

for n in n_list:
    for r in range(n_reps):
        ### Set Seed
        # NumPy's global RNG and torch are both seeded with seed0 + r.
        # NOTE(review): the C values below share one RNG stream, so each C gets
        # a different model init / test sample — confirm this is intended.
        np.random.seed(seed0 + r)
        torch.manual_seed(seed0 + r)

        ### Generate Data
        # source X
        if source == "normal":
            X_np = np.random.multivariate_normal(np.zeros(d), np.eye(d), size=n)
        elif source == "t":
            X_np = generate_multivariate_t(df=6, d=d, n_samples=n)
        elif source == "uniform":
            X_np = np.random.uniform(0, 1, size=(n, d))
        else:
            raise ValueError(f"unsupported source distribution: {source!r}")

        # target Y
        if target == "normal":
            Y_np = np.random.multivariate_normal(np.zeros(d), np.eye(d), size=n)
        elif target == "uniform":
            Y_np = np.random.uniform(0, 1, size=(n, d))
        elif target == "t":
            Y_np = generate_multivariate_t(df=6, d=d, n_samples=n)
        else:
            raise ValueError(f"unsupported target distribution: {target!r}")

        # Training: every sieve constant C is fit on the same (X_np, Y_np).
        for C in Cs:
            results.append(run_tnn(n, C, r, X_np, Y_np))



#%%
def _std_across_reps(a):
    """Return the sample standard deviation (ddof=1) of ``a`` along axis 0.

    With fewer than two repetitions ddof=1 is undefined (NumPy returns NaN and
    warns), so zeros are returned instead and the plots show no spread.
    """
    a = np.asarray(a, dtype=float)
    if a.shape[0] < 2:
        return np.zeros(a.shape[1:])
    return a.std(axis=0, ddof=1)


# Group per-run results by (n, C) in a single pass:
#   data[(n, C)]    -> per-repetition train / val loss curves
#   l2_data[(n, C)] -> per-repetition L2-UVP values
data = defaultdict(lambda: {'train': [], 'val': []})
l2_data = defaultdict(list)
for n_rec, C_rec, _r, l2_rec, tr, vl in results:
    data[(n_rec, C_rec)]['train'].append(np.array(tr))
    data[(n_rec, C_rec)]['val'].append(np.array(vl))
    l2_data[(n_rec, C_rec)].append(l2_rec)

log_n = np.log(n_list)
mid = (len(Cs) - 1) / 2  # centre index, so the jitter is symmetric about each n

# Plot train & val loss
for n0 in n_list:
    fig, axes = plt.subplots(1, 2, figsize=(12, 4), sharex=True)
    for ax, kind in zip(axes, ['train', 'val']):
        for C in Cs:
            # rows = repetitions, columns = epochs (np.stack needs equal lengths)
            curves = np.stack(data[(n0, C)][kind], axis=0)
            epochs = np.arange(1, curves.shape[1] + 1)
            mu = curves.mean(axis=0)
            sigma = _std_across_reps(curves)
            ax.plot(epochs, mu, label=f"C={C}")
            ax.fill_between(epochs, mu - sigma, mu + sigma, alpha=0.2)
        ax.relim()
        ax.autoscale_view()
        ax.set_xlabel("Epoch")
        ax.set_ylabel(f"{kind.capitalize()} Loss")
        ax.set_title(f"{kind.capitalize()} (n={n0})")
        ax.grid(ls='--', alpha=0.3)
        ax.legend()
    fig.suptitle(f"TNN Curves n={n0}\n(src={source},tgt={target})", fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(os.path.join(
        fig_dir, f"loss_curves_n{n0}_{source}_{target}_d{d}.png"
    ), dpi=300)
    plt.show()


# OT-map L2-UVP: mean +/- std over repetitions, one series per C
plt.figure(figsize=(6, 4))
for idx, C in enumerate(Cs):
    # Small multiplicative x-offset per C so error bars at the same n don't overlap.
    x = np.exp(log_n + (idx - mid) * 0.04)

    per_n = [np.array(l2_data[(n, C)]) for n in n_list]
    means = [arr.mean() for arr in per_n]
    stds = [_std_across_reps(arr) for arr in per_n]

    plt.errorbar(
        x, means, yerr=stds,
        fmt='o-', capsize=5, elinewidth=1, markeredgewidth=1.5,
        label=f"C={C}"
    )

plt.xscale('log')
plt.yscale('log')
plt.xticks(n_list, n_list)
plt.xlabel("Number of samples $n$")
plt.ylabel("L2-UVP")
plt.title(f"L2-UVP vs $n$  (src={source}, tgt={target})")
plt.legend(ncol=1, frameon=False)
plt.tight_layout()
plt.savefig(os.path.join(fig_dir, f"l2_UVP_vs_n_{source}_{target}_jittered_d{d}.png"), dpi=300)
plt.show()